{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a1186fb8",
   "metadata": {},
   "source": [
    "### P051 随机森林 - 训练随机森林分类模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ef6fac26",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e108faae",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "62686e53",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.datasets import make_moons"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "6bab9f5b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "f012498c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.ensemble import RandomForestClassifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "e565e57f",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.random.seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "496f47cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate the make_moons toy dataset (2000 samples, noise 0.25); the fixed\n",
    "# random_state keeps the generated data identical on every run.\n",
    "raw_data = make_moons(\n",
    "    n_samples=2000,\n",
    "    noise=0.25,\n",
    "    random_state=42,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "0507c4a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# make_moons returns an (X, y) pair: unpack features and labels in one step.\n",
    "data, target = raw_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "dadd6a19",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((2000, 2), (2000,))"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data.shape, target.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "529aa7b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pin the shuffle explicitly so the split is the same no matter how often or in\n",
    "# what order cells are re-run, instead of depending on NumPy's global RNG state.\n",
    "x_train, x_test, y_train, y_test = train_test_split(\n",
    "    data, target, random_state=42\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "034ec6c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "classifer = RandomForestClassifier(random_state=42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "b542ea92",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "RandomForestClassifier(random_state=42)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "classifer.fit(x_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "4bbcb477",
   "metadata": {},
   "outputs": [],
   "source": [
    "acc = classifer.score(x_test, y_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "e8e0b15e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.93"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "acc"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e124079b",
   "metadata": {},
   "source": [
    "### P052 随机森林 - 网格搜索获取最优参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "1ce21643",
   "metadata": {},
   "outputs": [],
   "source": [
    "param_grid = {\n",
    "    \"criterion\" : [\"gini\", \"entropy\"],\n",
    "    \"max_depth\" : [6, 7, 8],\n",
    "    \"min_samples_leaf\" : [4, 5]\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "d2155bdc",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import GridSearchCV"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "942282e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Exhaustive search over param_grid, scored by accuracy with 2-fold\n",
    "# cross-validation; n_jobs=-1 lets scikit-learn use every available processor.\n",
    "grid_search = GridSearchCV(\n",
    "    estimator=classifer,\n",
    "    param_grid=param_grid,\n",
    "    scoring=\"accuracy\",\n",
    "    cv=2,\n",
    "    n_jobs=-1,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "4d409939",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "GridSearchCV(cv=2, estimator=RandomForestClassifier(random_state=42), n_jobs=-1,\n",
       "             param_grid={'criterion': ['gini', 'entropy'],\n",
       "                         'max_depth': [6, 7, 8], 'min_samples_leaf': [4, 5]},\n",
       "             scoring='accuracy')"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "grid_search.fit(x_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "9c6842db",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.932"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Keep the tuned model's test accuracy under its own name so the baseline\n",
    "# `acc` computed earlier is not overwritten and both scores stay comparable.\n",
    "grid_acc = grid_search.score(x_test, y_test)\n",
    "grid_acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "04d21170",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'criterion': 'gini', 'max_depth': 8, 'min_samples_leaf': 4}"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "grid_search.best_params_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c3c9757",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
